import onnxruntime as ort
import numpy as np
from pathlib import Path
import os

class SimpleInfer:
    """Minimal ONNX Runtime wrapper that runs a model on a single input.

    Only the model's first input and first output are used: ``infer`` feeds
    its argument to input 0 and returns the value of output 0.
    """

    def __init__(self, model_path, **session_kwargs):
        """Create the underlying ``ort.InferenceSession``.

        Args:
            model_path: Location of the ``.onnx`` model. ``str`` and any
                ``os.PathLike`` (e.g. ``pathlib.Path``) are accepted;
                path-like objects are converted with ``os.fspath`` because
                some older onnxruntime releases reject them. Anything else
                (e.g. serialized model ``bytes``) is passed through as-is.
            **session_kwargs: Forwarded unchanged to
                ``ort.InferenceSession`` (e.g. ``providers=[...]`` or
                ``sess_options=...``). Builds that ship more than one
                execution provider may require ``providers`` explicitly.
        """
        if isinstance(model_path, os.PathLike):
            model_path = os.fspath(model_path)
        self.ort_session = ort.InferenceSession(model_path, **session_kwargs)

    def infer(self, input_data):
        """Run the model once and return its first output.

        Args:
            input_data: Value bound to the model's first input. Presumably a
                numpy array whose shape/dtype match that input; this is not
                checked here.

        Returns:
            Element 0 of ``InferenceSession.run``'s result list, i.e. the
            value of the model's first output (other outputs are not
            requested).
        """
        # Names are looked up from the session on each call so the wrapper
        # always reflects whatever session is currently attached.
        input_name = self.ort_session.get_inputs()[0].name
        output_name = self.ort_session.get_outputs()[0].name
        outputs = self.ort_session.run([output_name], {input_name: input_data})
        return outputs[0]
    

if __name__ == '__main__':
    # Smoke test: locate the model relative to this script, push one random
    # batch through it and report the shape of the first output.
    here = os.path.dirname(os.path.abspath(__file__))
    print("cur_path: ", here)
    onnx_file = os.path.join(here, 'model/alpha-pose-136.onnx')
    print("model_path: ", onnx_file)

    # Uniform [0, 1) float32 values of shape (1, 3, 256, 192); presumably
    # NCHW to match the model's first input — TODO confirm against the model.
    dummy_batch = np.random.rand(1, 3, 256, 192).astype(np.float32)
    runner = SimpleInfer(onnx_file)
    print(runner.infer(dummy_batch).shape)